import numpy
import pandas
#import lmpfrc
from mdprop import Prop
def load_shake(lmp_dir=None, atoms=None):
    """Load a dumped SHAKE cluster table and attach the type of every member atom.

    Reads ``shake_list.bin``, ``shake_flag.bin`` and ``shake_atom.bin`` (raw
    int32) from ``lmp_dir``.  ``shake_list`` selects and orders the rows of the
    per-atom flag/atom tables; ``shake_atom`` is read as rows of 4 atom ids.

    Parameters
    ----------
    lmp_dir : str, optional
        Path prefix that the file names are appended to by plain string
        concatenation, so it should end with a path separator.  Defaults to
        ``lmpfrc.lmp_dir`` (NOTE(review): the ``lmpfrc`` import is commented
        out at the top of this file, so the default only works if it is
        provided elsewhere).
    atoms : pandas.DataFrame, optional
        Atom table with at least ``id`` and ``type`` columns.  Defaults to
        ``msys.data.atoms`` (NOTE(review): ``msys`` is not defined in this
        file).

    Returns
    -------
    pandas.DataFrame
        Columns ``i, j, k, l, flag`` followed by ``it, jt, kt, lt``, the types
        of the corresponding atoms.  The type is NaN where the id has no match
        in ``atoms``.
    """
    if lmp_dir is None:
        lmp_dir = lmpfrc.lmp_dir
    if atoms is None:
        atoms = msys.data.atoms

    shake_list = numpy.fromfile(lmp_dir + "shake_list.bin", dtype=numpy.int32)
    shake_flag = numpy.fromfile(lmp_dir + "shake_flag.bin", dtype=numpy.int32)[shake_list]
    shake_atom = numpy.fromfile(lmp_dir + "shake_atom.bin", dtype=numpy.int32).reshape((-1, 4))[shake_list]

    # shake_flag is 1-D; make it a column so it can sit next to the (n, 4) atom table.
    table = numpy.concatenate([shake_atom, shake_flag[:, None]], axis=1)
    shaket = pandas.DataFrame(table, columns=['i', 'j', 'k', 'l', 'flag'])

    types = atoms[['id', 'type']]
    for x in 'ijkl':
        shaket = (shaket.merge(types, left_on=x, right_on='id', how='left')
                  .rename(columns={'type': x + 't'})
                  .drop('id', axis=1))
    return shaket

class RigidShake:
    """SHAKE bond/angle constraints applied as corrective forces.

    Clusters are a centre atom ``i`` plus up to three partners ``j, k, l``;
    ``find_clusters`` builds them from atom/bond tables.  The solvers
    (``shake``, ``shake3``, ``shake4``, ``shake3angle``) take the current
    positions ``x`` and the unconstrained next positions ``xshake``, solve for
    Lagrange multipliers, and add the constraint forces to ``f`` in place.
    The coefficient names and iteration scheme appear to mirror LAMMPS
    ``fix shake`` (NOTE(review): confirm against the reference).
    """

    # Configuration schema read from ``config.runprops.rigidprops``;
    # presumably Prop(name, type, default) -- see mdprop.
    Props = [
        Prop('masses', list, []),
        Prop('types', list, []),
        Prop('bonds', list, []),
        Prop('angles', list, []),
        Prop('maxiter', int, 500),
        Prop('tol', float, 1e-6),
    ]

    def __init__(self, config):
        """Build the atom/bond/angle selection rules from the run configuration.

        Reads ``config.runprops.rigidprops``:

        - ``masses``: atoms within +-0.1 of any of these masses are SHAKE atoms.
        - ``types``: atoms of any of these types are SHAKE atoms.
        - ``bonds``: pairs ``(a, b)``; bonds between an ``a``-typed and a
          ``b``-typed atom are constrained, with the ``a`` atom as the centre.
        - ``angles``: triples ``(a, centre, c)``; 3-atom clusters whose centre
          and partners match also get their j-k distance constrained.
        - ``maxiter``, ``tol``: iteration limit and tolerance of the iterative
          solvers.

        The time step is read from ``config.runprops.dt``.
        """
        props = config.runprops.rigidprops
        self.maxiter = props.maxiter
        self.tol = props.tol

        # Atoms matching any predicate are SHAKE atoms.  The first entry never
        # matches; it only keeps the joined query valid when nothing is configured.
        atom_queries = ['(type == "")']
        for mass in props.masses or []:
            atom_queries.append('(mass >= %f - 0.1 and mass <= %f + 0.1)' % (mass, mass))
        for t in props.types or []:
            atom_queries.append('(type == "%s")' % t)
        self.atom_query = ' or '.join(atom_queries)

        # Bond-type selection: the "fwd" query matches bonds stored as (a, b)
        # and the "rev" query bonds stored as (b, a).
        self.bond_query_fwd = None
        self.bond_query_rev = None
        if props.bonds:
            fwd = []
            rev = []
            for btype in props.bonds:
                # Config entries may be lists; %-formatting needs a tuple.
                a, b = btype
                fwd.append('(t0 == "%s" and t1 == "%s")' % (a, b))
                rev.append('(t0 == "%s" and t1 == "%s")' % (b, a))
            self.bond_query_fwd = ' or '.join(fwd)
            self.bond_query_rev = ' or '.join(rev)

        # Angle selection table keyed by (centre, partner1, partner2), stored
        # in both partner orders.  Duplicates are dropped so a left merge
        # against it never adds rows.
        self.angles_frame = None
        if props.angles:
            angles_centered = []
            for angle in props.angles:
                angles_centered.append((angle[1], angle[0], angle[2], 1))
                if angle[0] != angle[2]:
                    angles_centered.append((angle[1], angle[2], angle[0], 1))
            self.angles_frame = pandas.DataFrame(
                angles_centered, columns=['ti', 'tj', 'tk', 'newflag']
            ).drop_duplicates(['ti', 'tj', 'tk'])

        dt = config.runprops.dt
        self.dtv = dt
        # dt**2 scaled by 1/48.88821291**2; presumably the LAMMPS "real"-units
        # force-to-velocity factor -- TODO confirm the unit system.
        self.dtfsq = dt * dt * (1. / 48.88821291 / 48.88821291)

        # Populated by find_clusters; empty until then.
        self.shakes = []
        self.shakes_frame = None

    @staticmethod
    def _lookup_param(params_it, keys, column):
        """Return ``params_it[column]`` for every row of ``keys`` as an array.

        ``params_it`` is indexed by a MultiIndex of type names, and ``keys`` has
        one column per index level, in level order.  Rows not found in that
        order are retried with the key order reversed (a-b vs b-a, a-b-c vs
        c-b-a).  Rows found in neither order come back as NaN.
        """
        arrays = [keys[c].values for c in keys.columns]
        fwd = params_it.reindex(pandas.MultiIndex.from_arrays(arrays))[column].values
        rev = params_it.reindex(pandas.MultiIndex.from_arrays(arrays[::-1]))[column].values
        return numpy.where(pandas.isnull(fwd), rev, fwd)

    @staticmethod
    def _inverse_determinant(determ):
        """Return ``1/determ``, raising ZeroDivisionError for a singular SHAKE matrix."""
        if determ == 0:
            raise ZeroDivisionError("Shake determinant = 0.0")
        return 1. / determ

    def find_clusters(self, atoms, bonds, angles, bond_params, angle_params):
        """Group constrained bonds into clusters and cache them on ``self``.

        Expected inputs (inferred from the columns used here -- verify against
        callers):

        - ``atoms``: ``id``, ``type``, ``tid``, plus any columns referenced by
          the atom query (``mass``).
        - ``bonds``: ``i``, ``j``, ``pid``, plus ``t0``/``t1`` when bond types
          are configured.
        - ``bond_params``: ``t0``, ``t1``, ``r0``.
        - ``angle_params``: ``t0``, ``t1`` (centre), ``t2``, ``theta0``.
          ``theta0`` is passed to ``numpy.cos``, so presumably it is in radians.

        ``angles`` is accepted for interface compatibility but is not used.

        Sets ``self.shakes_frame``, with 1-based atom ids where 0 means "no
        atom", and ``self.shakes``, a list of record dicts with 0-based indices
        (absent atoms become -1).  ``flag`` values:

        - 2: a single bond i-j.
        - 3: bonds i-j and i-k.
        - 4: bonds i-j, i-k and i-l.
        - 1: like 3, plus a j-k distance derived from the configured angle.

        Partners beyond the third for a single centre are ignored.
        """
        shake_id = atoms.query(self.atom_query)[['id']]

        bonds_ijp = bonds[['i', 'j', 'pid']].copy()
        # Orient every selected bond so the SHAKE atom is 'j' and the centre is 'i'.
        i_shake = shake_id.merge(bonds_ijp, left_on='id', right_on='i', how='inner').drop('id', axis=1)
        j_shake = shake_id.merge(bonds_ijp, left_on='id', right_on='j', how='inner').drop('id', axis=1)
        i_shake = i_shake.rename(columns={'i': 'j', 'j': 'i'})

        shake_bonds = pandas.concat([i_shake, j_shake], sort=True, ignore_index=True)

        if self.bond_query_fwd:
            # Bonds stored as (b, a) are flipped so the a-typed atom is the centre.
            ib_shake = bonds.query(self.bond_query_rev)[['i', 'j', 'pid']]
            jb_shake = bonds.query(self.bond_query_fwd)[['i', 'j', 'pid']]
            ib_shake = ib_shake.rename(columns={'i': 'j', 'j': 'i'})
            shake_bonds = (pandas.concat([shake_bonds, ib_shake, jb_shake], sort=True, ignore_index=True)
                           .drop_duplicates()
                           .reset_index(drop=True))

        if shake_bonds.empty:
            self.shakes_frame = pandas.DataFrame(columns=['i', 'j', 'k', 'l', 'flag'])
            self.shakes = []
            return

        # Label each centre's partners j, k, l in order of appearance.
        slot = shake_bonds.groupby('i').cumcount()
        partners = shake_bonds.assign(n=slot.map({0: 'j', 1: 'k', 2: 'l'}))
        partners = partners[partners.n.notnull()]

        # One row per centre; missing partners are 0.  Reindexing guarantees
        # that k and l exist even when no cluster has 3 or 4 atoms.
        shakes = (partners.set_index(['i', 'n'])['j']
                  .unstack(fill_value=0)
                  .reindex(columns=['j', 'k', 'l'], fill_value=0)
                  .reset_index())
        shakes.columns.name = None
        shakes[['i', 'j', 'k', 'l']] = shakes[['i', 'j', 'k', 'l']].astype(numpy.int64)

        types = atoms[['id', 'type', 'tid']]
        for x in 'ijkl':
            shakes = (shakes.merge(types, left_on=x, right_on='id', how='left')
                      .drop('id', axis=1)
                      .rename(columns={'type': 't' + x, 'tid': 'tid' + x}))

        shakes['flag'] = 0
        shakes.loc[shakes.j != 0, 'flag'] = 2
        shakes.loc[shakes.k != 0, 'flag'] = 3
        shakes.loc[shakes.l != 0, 'flag'] = 4
        if self.angles_frame is not None:
            possible_angles = (shakes.flag == 3).values
            # angles_frame is de-duplicated, so this left merge keeps one row per
            # cluster, in order.
            matched_types = (shakes[['ti', 'tj', 'tk']]
                             .merge(self.angles_frame, on=['ti', 'tj', 'tk'], how='left')
                             .newflag.eq(1).values)
            shakes.loc[possible_angles & matched_types, 'flag'] = 1

        bond_params_it = bond_params.set_index(['t0', 't1'])
        angle_params_it = angle_params.set_index(['t0', 't1', 't2'])
        shakes['r0ij'] = self._lookup_param(bond_params_it, shakes[['ti', 'tj']], 'r0')
        for col in ('r0ik', 'r0il', 'r0jk'):
            shakes[col] = numpy.nan

        hasik = shakes.flag.isin([1, 3, 4])
        shakes.loc[hasik, 'r0ik'] = self._lookup_param(bond_params_it, shakes.loc[hasik, ['ti', 'tk']], 'r0')
        hasil = shakes.flag == 4
        shakes.loc[hasil, 'r0il'] = self._lookup_param(bond_params_it, shakes.loc[hasil, ['ti', 'tl']], 'r0')

        # j-k distance from the law of cosines on the two bonds and the angle.
        hasjk = shakes.flag == 1
        thetajk = self._lookup_param(angle_params_it, shakes.loc[hasjk, ['tj', 'ti', 'tk']], 'theta0')
        trij = shakes.loc[hasjk, 'r0ij'].values
        trik = shakes.loc[hasjk, 'r0ik'].values
        r0jksq = trij * trij + trik * trik - 2. * numpy.cos(thetajk) * trij * trik
        shakes.loc[hasjk, 'r0jk'] = numpy.sqrt(r0jksq)

        self.shakes_frame = shakes
        shakes0b = shakes.copy()
        shakes0b[['i', 'j', 'k', 'l']] -= 1
        self.shakes = shakes0b.to_dict('records')

    def unconstrained_update(self, x, v, f, mass):
        """Return the positions after one step with no constraints applied.

        Computes ``x + dtv*v + dtfsq*f/mass`` per atom.  ``mass`` is per-atom and
        is broadcast over the coordinate axis.
        """
        xshake = x + self.dtv * v + f * self.dtfsq / mass.reshape((-1, 1))
        return xshake

    def shake(self, x, mass, xshake, f, i, j, r0):
        """Constrain the single bond i-j to length ``r0``.

        Solves the quadratic for the multiplier exactly, then adds ``+lam*r`` to
        ``f[i]`` and ``-lam*r`` to ``f[j]``.
        """
        r = x[i] - x[j]
        s = xshake[i] - xshake[j]

        rsq = r.dot(r)
        ssq = s.dot(s)

        minvi = 1. / mass[i]
        minvj = 1. / mass[j]

        a = (minvi + minvj) ** 2 * rsq
        b = 2. * (minvi + minvj) * s.dot(r)
        c = ssq - r0 * r0

        # A negative discriminant is clamped to zero instead of failing.
        determ = b * b - 4 * a * c
        determ = max(determ, 0)
        lam1 = (-b + numpy.sqrt(determ)) / (2. * a)
        lam2 = (-b - numpy.sqrt(determ)) / (2. * a)
        # Take the root with the smaller magnitude (the smaller correction).
        lam = lam1 if (abs(lam1) < abs(lam2)) else lam2

        lam /= self.dtfsq

        f[i] += lam * r
        f[j] -= lam * r

    def shake3(self, x, mass, xshake, f, i, j, k, r0ij, r0ik):
        """Constrain bonds i-j and i-k to lengths ``r0ij`` and ``r0ik``.

        The multipliers are found by fixed-point iteration on the linearised
        2x2 system plus quadratic corrections, for up to ``self.maxiter``
        iterations or until both change by at most ``self.tol``.

        Raises ZeroDivisionError if the constraint matrix is singular.
        """
        rij = x[i] - x[j]
        rik = x[i] - x[k]

        sij = xshake[i] - xshake[j]
        sik = xshake[i] - xshake[k]

        rijsq = rij.dot(rij)
        riksq = rik.dot(rik)

        sijsq = sij.dot(sij)
        siksq = sik.dot(sik)

        minvi = 1. / mass[i]
        minvj = 1. / mass[j]
        minvk = 1. / mass[k]

        a11 = 2. * (minvi + minvj) * sij.dot(rij)
        a12 = 2. * minvi * sij.dot(rik)
        a21 = 2. * minvi * sik.dot(rij)
        a22 = 2. * (minvi + minvk) * sik.dot(rik)

        determ = a11 * a22 - a12 * a21
        determinv = self._inverse_determinant(determ)

        a11inv = a22 * determinv
        a12inv = -a12 * determinv
        a21inv = -a21 * determinv
        a22inv = a11 * determinv

        rijik = rij.dot(rik)

        quad1_jj = (minvi + minvj) ** 2 * rijsq
        quad1_kk = minvi * minvi * riksq
        quad1_jk = 2. * (minvi + minvj) * minvi * rijik

        quad2_kk = (minvi + minvk) ** 2 * riksq
        quad2_jj = minvi * minvi * rijsq
        quad2_jk = 2. * (minvi + minvk) * minvi * rijik

        lam1 = 0
        lam2 = 0

        done = False
        niter = 0

        while not done and niter < self.maxiter:
            quad1 = quad1_jj * lam1 * lam1 + quad1_kk * lam2 * lam2 + quad1_jk * lam1 * lam2
            quad2 = quad2_jj * lam1 * lam1 + quad2_kk * lam2 * lam2 + quad2_jk * lam1 * lam2

            b1 = r0ij ** 2 - sijsq - quad1
            b2 = r0ik ** 2 - siksq - quad2

            lam1_new = a11inv * b1 + a12inv * b2
            lam2_new = a21inv * b1 + a22inv * b2

            done = abs(lam1_new - lam1) <= self.tol and abs(lam2_new - lam2) <= self.tol
            lam1 = lam1_new
            lam2 = lam2_new
            # Stop a diverging iteration before it overflows (same guard as
            # shake4 / shake3angle).
            done = done or (abs(lam1) > 1e150 or abs(lam2) > 1e150)
            niter += 1
        lam1 = lam1 / self.dtfsq
        lam2 = lam2 / self.dtfsq

        f[i] += lam1 * rij + lam2 * rik
        f[j] -= lam1 * rij
        f[k] -= lam2 * rik

    def shake4(self, x, mass, xshake, f, i, j, k, l, r0ij, r0ik, r0il):
        """Constrain bonds i-j, i-k and i-l to ``r0ij``, ``r0ik`` and ``r0il``.

        The method is the same as ``shake3``, but with a 3x3 system that is
        inverted explicitly via cofactors.

        Raises ZeroDivisionError if the constraint matrix is singular.
        """
        rij = x[i] - x[j]
        rik = x[i] - x[k]
        ril = x[i] - x[l]

        sij = xshake[i] - xshake[j]
        sik = xshake[i] - xshake[k]
        sil = xshake[i] - xshake[l]

        rijsq = rij.dot(rij)
        riksq = rik.dot(rik)
        rilsq = ril.dot(ril)

        sijsq = sij.dot(sij)
        siksq = sik.dot(sik)
        silsq = sil.dot(sil)

        minvi = 1. / mass[i]
        minvj = 1. / mass[j]
        minvk = 1. / mass[k]
        minvl = 1. / mass[l]

        a11 = 2. * (minvi + minvj) * sij.dot(rij)
        a12 = 2. * minvi * sij.dot(rik)
        a13 = 2. * minvi * sij.dot(ril)

        a21 = 2. * minvi * sik.dot(rij)
        a22 = 2. * (minvi + minvk) * sik.dot(rik)
        a23 = 2. * minvi * sik.dot(ril)

        a31 = 2. * minvi * sil.dot(rij)
        a32 = 2. * minvi * sil.dot(rik)
        a33 = 2. * (minvi + minvl) * sil.dot(ril)

        determ = a11*a22*a33 + a12*a23*a31 + a13*a21*a32 - a11*a23*a32 - a12*a21*a33 - a13*a22*a31
        determinv = self._inverse_determinant(determ)

        a11inv =  determinv * (a22*a33 - a23*a32)
        a12inv = -determinv * (a12*a33 - a13*a32)
        a13inv =  determinv * (a12*a23 - a13*a22)
        a21inv = -determinv * (a21*a33 - a23*a31)
        a22inv =  determinv * (a11*a33 - a13*a31)
        a23inv = -determinv * (a11*a23 - a13*a21)
        a31inv =  determinv * (a21*a32 - a22*a31)
        a32inv = -determinv * (a11*a32 - a12*a31)
        a33inv =  determinv * (a11*a22 - a12*a21)

        rijik = rij.dot(rik)
        rijil = rij.dot(ril)
        rikil = rik.dot(ril)

        quad1_jj = (minvi+minvj)*(minvi+minvj) * rijsq
        quad1_kk = minvi*minvi * riksq
        quad1_ll = minvi*minvi * rilsq
        quad1_jk = 2.0 * (minvi+minvj)*minvi * rijik
        quad1_jl = 2.0 * (minvi+minvj)*minvi * rijil
        quad1_kl = 2.0 * minvi*minvi * rikil

        quad2_jj = minvi*minvi * rijsq
        quad2_kk = (minvi+minvk)*(minvi+minvk) * riksq
        quad2_ll = minvi*minvi * rilsq
        quad2_jk = 2.0 * (minvi+minvk)*minvi * rijik
        quad2_jl = 2.0 * minvi*minvi * rijil
        quad2_kl = 2.0 * (minvi+minvk)*minvi * rikil

        quad3_jj = minvi*minvi * rijsq
        quad3_kk = minvi*minvi * riksq
        quad3_ll = (minvi+minvl)*(minvi+minvl) * rilsq
        quad3_jk = 2.0 * minvi*minvi * rijik
        quad3_jl = 2.0 * (minvi+minvl)*minvi * rijil
        quad3_kl = 2.0 * (minvi+minvl)*minvi * rikil

        lam1 = 0
        lam2 = 0
        lam3 = 0

        done = False
        niter = 0

        while not done and niter < self.maxiter:
            quad1 = quad1_jj*lam1*lam1 + quad1_kk*lam2*lam2 + quad1_ll*lam3*lam3 + quad1_jk*lam1*lam2 + quad1_jl*lam1*lam3 + quad1_kl*lam2*lam3
            quad2 = quad2_jj*lam1*lam1 + quad2_kk*lam2*lam2 + quad2_ll*lam3*lam3 + quad2_jk*lam1*lam2 + quad2_jl*lam1*lam3 + quad2_kl*lam2*lam3
            quad3 = quad3_jj*lam1*lam1 + quad3_kk*lam2*lam2 + quad3_ll*lam3*lam3 + quad3_jk*lam1*lam2 + quad3_jl*lam1*lam3 + quad3_kl*lam2*lam3

            b1 = r0ij**2 - sijsq - quad1
            b2 = r0ik**2 - siksq - quad2
            b3 = r0il**2 - silsq - quad3

            lam1_new = a11inv*b1 + a12inv*b2 + a13inv*b3
            lam2_new = a21inv*b1 + a22inv*b2 + a23inv*b3
            lam3_new = a31inv*b1 + a32inv*b2 + a33inv*b3

            done = abs(lam1_new - lam1) <= self.tol and abs(lam2_new - lam2) <= self.tol and abs(lam3_new - lam3) <= self.tol
            lam1 = lam1_new
            lam2 = lam2_new
            lam3 = lam3_new
            # Stop a diverging iteration before it overflows.
            done = done or (abs(lam1) > 1e150 or abs(lam2) > 1e150 or abs(lam3) > 1e150)
            niter += 1
        lam1 = lam1/self.dtfsq
        lam2 = lam2/self.dtfsq
        lam3 = lam3/self.dtfsq
        f[i] += lam1*rij + lam2*rik + lam3*ril
        f[j] -= lam1 * rij
        f[k] -= lam2 * rik
        f[l] -= lam3 * ril

    def shake3angle(self, x, mass, xshake, f, i, j, k, r0ij, r0ik, r0jk):
        """Constrain i-j, i-k and j-k to ``r0ij``, ``r0ik`` and ``r0jk``.

        This makes the triangle rigid, fixing the j-i-k angle.  It uses the same
        iterative scheme as ``shake4``, with multipliers for the i-j, i-k and
        j-k distances.

        Raises ZeroDivisionError if the constraint matrix is singular.
        """
        rij = x[i] - x[j]
        rik = x[i] - x[k]
        rjk = x[j] - x[k]

        sij = xshake[i] - xshake[j]
        sik = xshake[i] - xshake[k]
        sjk = xshake[j] - xshake[k]

        rijsq = rij.dot(rij)
        riksq = rik.dot(rik)
        rjksq = rjk.dot(rjk)

        sijsq = sij.dot(sij)
        siksq = sik.dot(sik)
        sjksq = sjk.dot(sjk)

        minvi = 1. / mass[i]
        minvj = 1. / mass[j]
        minvk = 1. / mass[k]

        a11 =  2. * (minvi + minvj) * sij.dot(rij)
        a12 =  2. * minvi * sij.dot(rik)
        a13 = -2. * minvj * sij.dot(rjk)
        a21 =  2. * minvi * sik.dot(rij)
        a22 =  2. * (minvi + minvk) * sik.dot(rik)
        a23 =  2. * minvk * sik.dot(rjk)
        a31 = -2. * minvj * sjk.dot(rij)
        a32 =  2. * minvk * sjk.dot(rik)
        a33 =  2. * (minvj + minvk) * sjk.dot(rjk)

        determ = a11*a22*a33 + a12*a23*a31 + a13*a21*a32 - a11*a23*a32 - a12*a21*a33 - a13*a22*a31
        determinv = self._inverse_determinant(determ)

        a11inv =  determinv * (a22*a33 - a23*a32)
        a12inv = -determinv * (a12*a33 - a13*a32)
        a13inv =  determinv * (a12*a23 - a13*a22)
        a21inv = -determinv * (a21*a33 - a23*a31)
        a22inv =  determinv * (a11*a33 - a13*a31)
        a23inv = -determinv * (a11*a23 - a13*a21)
        a31inv =  determinv * (a21*a32 - a22*a31)
        a32inv = -determinv * (a11*a32 - a12*a31)
        a33inv =  determinv * (a11*a22 - a12*a21)

        rijik = rij.dot(rik)
        rijjk = rij.dot(rjk)
        rikjk = rik.dot(rjk)

        quad1_ijij = (minvi+minvj)*(minvi+minvj) * rijsq
        quad1_ikik = minvi*minvi * riksq
        quad1_jkjk = minvj*minvj * rjksq
        quad1_ijik = 2.0 * (minvi+minvj)*minvi * rijik
        quad1_ijjk = - 2.0 * (minvi+minvj)*minvj * rijjk
        quad1_ikjk = - 2.0 * minvi*minvj * rikjk

        quad2_ijij = minvi*minvi * rijsq
        quad2_ikik = (minvi+minvk)*(minvi+minvk) * riksq
        quad2_jkjk = minvk*minvk * rjksq
        quad2_ijik = 2.0 * (minvi+minvk)*minvi * rijik
        quad2_ijjk = 2.0 * minvi*minvk * rijjk
        quad2_ikjk = 2.0 * (minvi+minvk)*minvk * rikjk

        quad3_ijij = minvj*minvj * rijsq
        quad3_ikik = minvk*minvk * riksq
        quad3_jkjk = (minvj+minvk)*(minvj+minvk) * rjksq
        quad3_ijik = - 2.0 * minvj*minvk * rijik
        quad3_ijjk = - 2.0 * (minvj+minvk)*minvj * rijjk
        quad3_ikjk = 2.0 * (minvj+minvk)*minvk * rikjk

        lamij = 0
        lamik = 0
        lamjk = 0

        done = False
        niter = 0

        while not done and niter < self.maxiter:
            quad1 = quad1_ijij*lamij*lamij + quad1_ikik*lamik*lamik + quad1_jkjk*lamjk*lamjk + quad1_ijik*lamij*lamik + quad1_ijjk*lamij*lamjk + quad1_ikjk*lamik*lamjk
            quad2 = quad2_ijij*lamij*lamij + quad2_ikik*lamik*lamik + quad2_jkjk*lamjk*lamjk + quad2_ijik*lamij*lamik + quad2_ijjk*lamij*lamjk + quad2_ikjk*lamik*lamjk
            quad3 = quad3_ijij*lamij*lamij + quad3_ikik*lamik*lamik + quad3_jkjk*lamjk*lamjk + quad3_ijik*lamij*lamik + quad3_ijjk*lamij*lamjk + quad3_ikjk*lamik*lamjk

            b1 = r0ij**2 - sijsq - quad1
            b2 = r0ik**2 - siksq - quad2
            b3 = r0jk**2 - sjksq - quad3

            lamij_new = a11inv*b1 + a12inv*b2 + a13inv*b3
            lamik_new = a21inv*b1 + a22inv*b2 + a23inv*b3
            lamjk_new = a31inv*b1 + a32inv*b2 + a33inv*b3

            done = abs(lamij_new - lamij) <= self.tol and abs(lamik_new - lamik) <= self.tol and abs(lamjk_new - lamjk) <= self.tol
            lamij = lamij_new
            lamik = lamik_new
            lamjk = lamjk_new
            # Stop a diverging iteration before it overflows.
            done = done or (abs(lamij) > 1e150 or abs(lamik) > 1e150 or abs(lamjk) > 1e150)
            niter += 1
        lamij = lamij/self.dtfsq
        lamik = lamik/self.dtfsq
        lamjk = lamjk/self.dtfsq

        f[i] +=  lamij*rij + lamik*rik
        f[j] += -lamij*rij + lamjk*rjk
        f[k] += -lamik*rik - lamjk*rjk

    def post_force(self, x, v, f, mass):
        """Add the constraint forces of every cluster in ``self.shakes`` to ``f``.

        ``self.shakes`` is built by ``find_clusters``; its indices are 0-based
        rows of ``x``/``v``/``f``/``mass``.  ``f`` is modified in place.
        """
        xshake = self.unconstrained_update(x, v, f, mass)
        for sk in self.shakes:
            flag = sk['flag']
            if flag == 2:
                self.shake(x, mass, xshake, f, sk['i'], sk['j'], sk['r0ij'])
            elif flag == 3:
                self.shake3(x, mass, xshake, f, sk['i'], sk['j'], sk['k'], sk['r0ij'], sk['r0ik'])
            elif flag == 4:
                self.shake4(x, mass, xshake, f, sk['i'], sk['j'], sk['k'], sk['l'], sk['r0ij'], sk['r0ik'], sk['r0il'])
            elif flag == 1:
                self.shake3angle(x, mass, xshake, f, sk['i'], sk['j'], sk['k'], sk['r0ij'], sk['r0ik'], sk['r0jk'])

    def correct_coordinates(self, x, mass):
        """Move ``x`` in place onto the constraint surface and return it.

        The constraints are solved with zero velocities and forces, so
        ``xshake == x``, and ``x`` is then displaced by ``dtfsq*ftmp/mass``.
        """
        ftmp = numpy.zeros(x.shape)
        vtmp = numpy.zeros(x.shape)

        self.post_force(x, vtmp, ftmp, mass)

        dtfmsq = self.dtfsq / mass
        x[:] = x + ftmp * dtfmsq.reshape((-1, 1))

        return x

    def shake_end_of_step(self, x, v, f, mass):
        """Add the constraint forces for the next step to ``f`` in place."""
        self.post_force(x, v, f, mass)

    def shake_setup(self, x, v, f, mass):
        """Initial setup: fix the coordinates, then precompute constraint forces.

        Both steps run with ``dtfsq`` halved.  The original value is restored
        afterwards, even if a step raises.
        """
        dtfsq = self.dtfsq
        self.dtfsq = 0.5 * dtfsq
        try:
            self.correct_coordinates(x, mass)
            self.shake_end_of_step(x, v, f, mass)
        finally:
            self.dtfsq = dtfsq
